import torch
import matplotlib.pyplot as plt
from torchvision.io import read_image

# Load one sample from the Penn-Fudan pedestrian data: the photo and the
# mask file that shares its base name.
image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
mask = read_image("data/PennFudanPed/PedMasks/FudanPed00046_mask.png")

# Report the tensor shapes so the axis layout can be checked by eye.
print(f"image.shape:{image.shape}")
print(f"mask.shape:{mask.shape}")

# Draw both tensors side by side in a single 1x2 figure.
plt.figure(figsize=(16, 8))
for _position, _caption, _tensor in ((121, "Image", image), (122, "Mask", mask)):
    plt.subplot(_position)
    plt.title(_caption)
    # permute(1, 2, 0) moves the first axis to the end; presumably this turns
    # the channels-first tensor into the channels-last layout imshow expects.
    plt.imshow(_tensor.permute(1, 2, 0))

plt.show()